#pragma once

#include <memory>
#include <vector>

#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {
// NOTE(review): this using-directive sits at namespace scope in a header, so
// every file that includes it gets the names of torch::jit::tensorexpr made
// visible inside torch::jit. The macros in this file spell Cast, Var,
// Intrinsics and kRand without qualification and presumably rely on it -
// confirm before narrowing or removing it.
using namespace torch::jit::tensorexpr;

// Checks that `node` can be dynamic_cast to `const T*`: the cast result is
// compared against nullptr with ASSERT_NE. The expansion is a single braced
// block, so the temporary `node_` stays out of the caller's scope.
#define IS_NODE(T, node)                           \
  {                                                \
    const T* node_ = dynamic_cast<const T*>(node); \
    ASSERT_NE(nullptr, node_);                     \
  }

// Declares `name` in the *enclosing* scope as dynamic_cast<const T*>(node)
// and checks it against nullptr with ASSERT_NE, so the caller can keep using
// `name` afterwards. Expands to two statements with no surrounding braces:
// it cannot serve as the body of an unbraced if/else/loop, and `name` must
// not already be declared in the scope where it is used.
#define IS_NODE_WITH_NAME(T, node, name)        \
  const T* name = dynamic_cast<const T*>(node); \
  ASSERT_NE(nullptr, name);

// Unwraps a Cast and exposes its operand. In order:
//   1. `node` must dynamic_cast to `const Cast*` (ASSERT_NE against nullptr);
//   2. the Cast's dtype().scalar_type() must equal ScalarType::<Type>;
//   3. the Cast's src_value() must dynamic_cast to `const T*`.
// The operand pointer is declared as `name` in the *enclosing* scope, while
// the intermediate `node_` is confined to an inner block. Because the
// expansion is several statements, it cannot be the body of an unbraced
// if/else/loop.
#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type)        \
  const T* name = nullptr;                                     \
  {                                                            \
    const Cast* node_ = dynamic_cast<const Cast*>(node);       \
    ASSERT_NE(nullptr, node_);                                 \
    ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \
    name = dynamic_cast<const T*>(node_->src_value());         \
  }                                                            \
  ASSERT_NE(nullptr, name);

// Checks that `node` is an immediate of kind `T` holding `val`. `T` is
// token-pasted with `Imm` (e.g. T = Int names the type IntImm); `node` must
// dynamic_cast to a const pointer of that type and its value() must compare
// equal to `val` under ASSERT_EQ. Everything happens inside one braced block.
#define IS_IMM_WITH_VAL(T, node, val)                        \
  {                                                          \
    const T##Imm* node_ = dynamic_cast<const T##Imm*>(node); \
    ASSERT_NE(nullptr, node_);                               \
    ASSERT_EQ(node_->value(), val);                          \
  }

// Checks that `node` is a Var whose name_hint() equals `name`: the node must
// dynamic_cast to `const Var*` (ASSERT_NE against nullptr) and the hint is
// compared with ASSERT_EQ. The expansion is one braced block, so `node_` is
// not visible to the caller.
#define IS_VAR_WITH_NAME(node, name)                 \
  {                                                  \
    const Var* node_ = dynamic_cast<const Var*>(node); \
    ASSERT_NE(nullptr, node_);                       \
    ASSERT_EQ(node_->name_hint(), name);             \
  }

// Checks that `node` is a binary operation of type `T` whose operands are
// both variables: `node` must dynamic_cast to `const T*`, then lhs() and
// rhs() are each verified with IS_VAR_WITH_NAME against `v1` and `v2`
// respectively. The cast result is declared as `name` in the *enclosing*
// scope so the caller can inspect it further; since the expansion is several
// statements, it cannot be the body of an unbraced if/else/loop.
#define IS_BINOP_W_VARS(T, node, name, v1, v2)  \
  const T* name = dynamic_cast<const T*>(node); \
  ASSERT_NE(nullptr, name);                     \
  IS_VAR_WITH_NAME(name->lhs(), v1);            \
  IS_VAR_WITH_NAME(name->rhs(), v2);

// Checks that `node` is a binary operation of type `T` of the form
// "variable op integer-constant": `node` must dynamic_cast to `const T*`,
// lhs() is verified with IS_VAR_WITH_NAME against `v`, and rhs() is verified
// with IS_IMM_WITH_VAL(Int, ...) against `c` (the immediate kind is fixed to
// Int here). The cast result is declared as `name` in the *enclosing* scope;
// since the expansion is several statements, it cannot be the body of an
// unbraced if/else/loop.
#define IS_BINOP_W_CONST(T, node, name, v, c)   \
  const T* name = dynamic_cast<const T*>(node); \
  ASSERT_NE(nullptr, name);                     \
  IS_VAR_WITH_NAME(name->lhs(), v);             \
  IS_IMM_WITH_VAL(Int, name->rhs(), c);

// Checks that `node` is an Intrinsics node whose op_type() is kRand: the
// node must dynamic_cast to `const Intrinsics*` (ASSERT_NE against nullptr)
// and the op type is compared with ASSERT_EQ. The expansion is one braced
// block, so `node_` is not visible to the caller.
#define IS_RAND(node)                                                \
  {                                                                  \
    const Intrinsics* node_ = dynamic_cast<const Intrinsics*>(node); \
    ASSERT_NE(nullptr, node_);                                       \
    ASSERT_EQ(node_->op_type(), kRand);                              \
  }

} // namespace jit
} // namespace torch
